{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-h05_PNNhZ-D"
   },
   "source": [
    "# Working with Pytrees\n",
    "\n",
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/05.1-pytrees.ipynb)\n",
    "\n",
    "*Author: Vladimir Mikulik*\n",
    "\n",
    "Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you can sometimes see them called *nests*, or just *trees*.\n",
    "\n",
    "JAX has built-in support for such objects, both in its library functions and through the functions in [`jax.tree_util`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section explains how to use them, gives some useful snippets, and points out common gotchas."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9UjxVY9ulSCn"
   },
   "source": [
    "## What is a pytree?\n",
    "\n",
    "As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html):\n",
    "\n",
    "> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.\n",
    "\n",
    "Some example pytrees:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Wh6BApZ9lrR1",
    "outputId": "37b8d89c-8dd0-4f2b-f479-8333f4b3a2c3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 'a', <object object at 0x7fded60bb8c0>]   has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>]\n",
      "(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]\n",
      "[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]\n",
      "{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]\n",
      "DeviceArray([1, 2, 3], dtype=int32)           has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "example_trees = [\n",
    "    [1, 'a', object()],\n",
    "    (1, (2, 3), ()),\n",
    "    [1, {'k1': 2, 'k2': (3, 4)}, 5],\n",
    "    {'a': 2, 'b': (2, 3)},\n",
    "    jnp.array([1, 2, 3]),\n",
    "]\n",
    "\n",
    "# Let's see how many leaves they have:\n",
    "for pytree in example_trees:\n",
    "  leaves = jax.tree_leaves(pytree)\n",
    "  print(f\"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_tWkkGNwW8vf"
   },
   "source": [
    "We've also introduced our first `jax.tree_*` function, `jax.tree_leaves`, which extracts the flattened leaves from a tree."
   ]
  },
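  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of what flattening gives you, the cell below uses `jax.tree_util.tree_flatten`, which returns the leaves together with a *treedef* describing the structure, and `jax.tree_util.tree_unflatten`, which rebuilds a tree from the two. The names `demo_tree`, `demo_leaves`, `demo_treedef` and `rebuilt_tree` are made up for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "demo_tree = [1, {'k1': 2, 'k2': (3, 4)}, 5]\n",
    "\n",
    "# Split the tree into a flat list of leaves and a description of its structure.\n",
    "demo_leaves, demo_treedef = jax.tree_util.tree_flatten(demo_tree)\n",
    "\n",
    "# Transform the leaves as an ordinary list, then restore the original structure.\n",
    "doubled_leaves = [2 * leaf for leaf in demo_leaves]\n",
    "rebuilt_tree = jax.tree_util.tree_unflatten(demo_treedef, doubled_leaves)\n",
    "\n",
    "print(demo_leaves)\n",
    "print(rebuilt_tree)"
   ]
  },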
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RcsmneIGlltm"
   },
   "source": [
    "## Why pytrees?\n",
    "\n",
    "In machine learning, some places where you commonly find pytrees are:\n",
    "* Model parameters\n",
    "* Dataset entries\n",
    "* RL agent observations\n",
    "\n",
    "They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sMrSGSIJn9MD"
   },
   "source": [
    "## Common pytree functions\n",
    "The most commonly used pytree functions are `jax.tree_map` and `jax.tree_multimap`. They work analogously to Python's native `map`, but on entire pytrees.\n",
    "\n",
    "For functions with one argument, use `jax.tree_map`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "wZRcuQu4n7o5",
    "outputId": "3528bc9f-54ed-49c8-b79a-1cbea176c0f3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
      ]
     },
     "execution_count": 2,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list_of_lists = [\n",
    "    [1, 2, 3],\n",
    "    [1, 2],\n",
    "    [1, 2, 3, 4]\n",
    "]\n",
    "\n",
    "jax.tree_map(lambda x: x*2, list_of_lists)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xu8X3fk4orC9"
   },
   "source": [
    "To use functions with more than one argument, use `jax.tree_multimap`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "KVpB4r1OkeUK",
    "outputId": "33f88a7e-aac7-48cd-d207-2c531cd37733"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
      ]
     },
     "execution_count": 3,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "another_list_of_lists = list_of_lists\n",
    "jax.tree_multimap(lambda x, y: x+y, list_of_lists, another_list_of_lists)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dkRKy3LvowAb"
   },
   "source": [
    "For `tree_multimap`, the structures of the inputs must match exactly. That is, lists must have the same number of elements, dicts must have the same keys, etc."
   ]
  },
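  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check ahead of time whether two trees are compatible is to compare their tree definitions. The sketch below assumes that comparing the results of `jax.tree_util.tree_structure` with `==` is an adequate test for this purpose; `same_treedef` and the example dicts are hypothetical helpers, not part of JAX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "first_tree = {'w': [1, 2], 'b': 3}\n",
    "matching_tree = {'w': [10, 20], 'b': 30}\n",
    "mismatched_tree = {'w': [10, 20, 30], 'b': 30}  # one extra list element\n",
    "\n",
    "def same_treedef(x, y):\n",
    "  \"\"\"Hypothetical helper: True if x and y have the same pytree structure.\"\"\"\n",
    "  return jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(y)\n",
    "\n",
    "print(same_treedef(first_tree, matching_tree))\n",
    "print(same_treedef(first_tree, mismatched_tree))"
   ]
  },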
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lla4hDW6sgMZ"
   },
   "source": [
    "## Example: ML model parameters\n",
    "\n",
    "A simple example of training an MLP shows some of the ways in which pytree operations come in handy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "j2ZUzWx8tKB2"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def init_mlp_params(layer_widths):\n",
    "  params = []\n",
    "  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):\n",
    "    params.append(\n",
    "        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),\n",
    "             biases=np.ones(shape=(n_out,))\n",
    "            )\n",
    "    )\n",
    "  return params\n",
    "\n",
    "params = init_mlp_params([1, 128, 128, 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kUFwJOspuGvU"
   },
   "source": [
    "We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "ErWsXuxXse-z",
    "outputId": "d3e549ab-40ef-470e-e460-1b5939d9696f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'biases': (128,), 'weights': (1, 128)},\n",
       " {'biases': (128,), 'weights': (128, 128)},\n",
       " {'biases': (1,), 'weights': (128, 1)}]"
      ]
     },
     "execution_count": 5,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_map(lambda x: x.shape, params)"
   ]
  },
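  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Flattening the tree also gives a quick parameter count. This sketch builds a small stand-in tree, `toy_params` (a hypothetical two-layer version of the structure above), so that it runs on its own; the same expression should apply to the real `params`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in with the same list-of-dicts layout as `params`.\n",
    "toy_params = [\n",
    "    dict(weights=np.zeros((1, 128)), biases=np.zeros((128,))),\n",
    "    dict(weights=np.zeros((128, 1)), biases=np.zeros((1,))),\n",
    "]\n",
    "\n",
    "# Each array is a leaf, so summing leaf sizes counts every scalar parameter.\n",
    "num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(toy_params))\n",
    "print(num_params)"
   ]
  },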
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zQtRKaj4ua6-"
   },
   "source": [
    "Now, let's train our MLP:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "iL4GvW9OuZ-X"
   },
   "outputs": [],
   "source": [
    "def forward(params, x):\n",
    "  *hidden, last = params\n",
    "  for layer in hidden:\n",
    "    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])\n",
    "  return x @ last['weights'] + last['biases']\n",
    "\n",
    "def loss_fn(params, x, y):\n",
    "  return jnp.mean((forward(params, x) - y) ** 2)\n",
    "\n",
    "LEARNING_RATE = 0.0001\n",
    "\n",
    "@jax.jit\n",
    "def update(params, x, y):\n",
    "\n",
    "  grads = jax.grad(loss_fn)(params, x, y)\n",
    "  # Note that `grads` is a pytree with the same structure as `params`.\n",
    "  # `jax.grad` is one of the many JAX functions that have\n",
    "  # built-in support for pytrees.\n",
    "\n",
    "  # This is handy, because we can apply the SGD update using tree utils:\n",
    "  return jax.tree_multimap(\n",
    "      lambda p, g: p - LEARNING_RATE * g, params, grads\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "B3HniT9-xohz",
    "outputId": "d77e9811-373e-45d6-ccbe-edb6f43120d7"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de3TU9Z3/8ec7w0QDKFFBrQkUTn8urUBIfkbUBS8LRawKjfwsar1fSn+nVkRbCqweCa0rVM7xQtfub63Xrjcixohal/XOYtWabDCISM/Wa6JbAQ1WiWaSfH5/TCZkwlzJXL4z83qcwwnz/X6ZvCckr3znczXnHCIi4l1F2S5ARERiU1CLiHicglpExOMU1CIiHqegFhHxuCHpeNKRI0e6sWPHpuOpRUTyUlNT0w7n3KhI59IS1GPHjqWxsTEdTy0ikpfM7P1o59T0ISLicQpqERGPU1CLiHhcWtqoIwkEArS2tvLVV19l6lNKBuy///6Ul5fj9/uzXYpI3spYULe2tnLAAQcwduxYzCxTn1bSyDnHzp07aW1tZdy4cdkuRyRvZSyov/rqK4V0njEzDjnkELZv357tUkSyqqG5jVXrt/FRewdHlJawaNZ4aqrKUvb8GQtqQCGdh/R/KoWuobmNpfWb6Qh0A9DW3sHS+s0AKQtrdSaKiAzCqvXb+kI6pCPQzar121L2OeIGtZmNN7NN/f58bmYLU1ZBBpkZ559/ft/jrq4uRo0axRlnnJHU84wdO5YdO3YM+prBeu+995g4cSIAjY2NLFiwIOb1N954Y9jjv//7v09bbSKF4qP2jqSO74u4Qe2c2+acq3TOVQJHA7uBx1JWQQYNGzaMN998k46O4BfwmWeeoawsde1IqdLV1ZX0v6murmb16tUxrxkY1H/84x+T/jwiEu6I0pKkju+LZJs+ZgB/cc5FneqYMi11cMtEqC0NfmypS8nTnnbaaTz11FMAPPTQQ5x77rl95z799FNqamqoqKjguOOOo6WlBYCdO3dyyimnMGHCBC6//HL674pz//33M2XKFCorK/nxj39Md3f4W6CBhg8fztVXX82ECROYMWNGX0fcySefzMKFC6murua2226jqamJk046iaOPPppZs2bx8ccfA9DU1MTkyZOZPHkyt99+e9/zvvjii33vDL744gsuueQSJk2aREVFBY8++ihLliyho6ODyspKzjvvvL5aIDh6Y9GiRUycOJFJkyaxZs2avuc8+eSTOeuss/j2t7/Neeedh3YEEgm3aNZ4Svy+sGMlfh+LZo1P2edINqjPAR6KdMLM5ptZo5k1DnoUQEsdPLEAdn0IuODHJxakJKzPOeccHn74Yb766itaWlo49thj+84tW7aMqqoqWlpauPHGG7nwwgsBWL58OdOmTWPLli2ceeaZfPDBBwBs3bqVNWvW8PLLL7Np0yZ8Ph8PPPBAzM//5ZdfUl1dzZYtWzjppJNYvnx537nOzs6+Jowrr7yStWvX0tTUxKWXXsq1114LwCWXXMJvfvMb3njjjaif41e/+hUjRoxg8+bNtLS0MH36dFauXElJSQmbNm3aq8b6+no2bdrEG2+8wbPPPsuiRYv6fjE0Nzdz66238tZbb/HOO+/w8ssvJ/HVFsl/NVVlrJg7ibLSEgwoKy1hxdxJ2Rn1YWbFwBxgaaTzzrk7gDsAqqurB3fb9dwvITCgfSfQETxeMW9QT11RUcF7773HQw89xGmnnRZ2buPGjTz66KMATJ8+nZ07d/L555+zYcMG6uvrATj99NM56KCDgmU+9xxNTU0cc8wxAHR0dHDooYfG/PxFRUWcffbZAJx//vnMnTu371zo+LZt23jzzTeZOXMmAN3d3XzjG9+gvb2d9vZ2TjzxRAAuuOACnn766b0+x7PPPsvDDz/c9zhUbzQbN27k3HPPxefzcdhhh3HSSSfx+uuvc+CBBzJlyhTK
y8sBqKys5L333mPatGkxn0+k0NRUlaU0mAdKZnje94D/cs79NV3F9NnVmtzxJM2ZM4ef//znvPjii+zcuXOfn8c5x0UXXcSKFSv2+Tn6D28bNmxY3/NOmDCBV155Jeza9vb2ff48+2q//fbr+7vP59un9nMRGZxkmj7OJUqzR8qNKE/ueJIuvfRSli1bxqRJk8KOn3DCCX3NAi+++CIjR47kwAMP5MQTT+TBBx8E4Omnn+azzz4DYMaMGaxdu5ZPPvkECLZxv/9+7Ob7np4e1q5dC8CDDz4Y8e50/PjxbN++vS+oA4EAW7ZsobS0lNLSUjZu3AgQtZll5syZYe3XoXr9fj+BQGCv60844QTWrFlDd3c327dvZ8OGDUyZMiXm6xCRzEkoqM1sGDATqE9vOb1mXA/+AT2m/pLg8RQoLy+POJSttraWpqYmKioqWLJkCffddx8QbLvesGEDEyZMoL6+njFjxgBw1FFHccMNN3DKKadQUVHBzJkz+9p2oxk2bBh/+tOfmDhxIs8//zzXX7/3ayouLmbt2rUsXryYyZMnU1lZ2TdC45577uGKK66gsrIyasfeddddx2effcbEiROZPHkyL7zwAgDz58+noqKirzMx5Mwzz6SiooLJkyczffp0brrpJg4//PA4X0URyRRLRy9+dXW1G7hxwNatW/nOd76T+JO01AXbpHe1Bu+kZ1w/6PZpLxg+fDhffPFFtstIqaT/b0VkL2bW5JyrjnQuo1PIk1IxLy+CWURksDSFPMPy7W5aRNIvo0GtyRL5R/+nIumXsaDef//92blzp36w80hoPer9998/26WI5LWMtVGXl5fT2tqqtYvzTGiHFxFJn4wFtd/v1y4gIiL7wLujPkREPCTdu7jEoqAWEYkjE7u4xKLheSIicWRiF5dYFNQiInFkYheXWBTUIiJxZGIXl1gU1CIicWRiF5dY1JkoIhJHqMNQoz5ERDws3bu4xKKmDxERj1NQi4h4nIJaRMTjFNQiIh6X6J6JpWa21szeNrOtZnZ8ugsTEZGgREd93Ab8u3PuLDMrBoamupBsLngiIuJlcYPazEYAJwIXAzjnOoHOVBaR7QVPRES8LJGmj3HAduAeM2s2szvNbNjAi8xsvpk1mlljspsDZHvBExERL0skqIcA/xv4F+dcFfAlsGTgRc65O5xz1c656lGjRiVVRLYXPBER8bJEgroVaHXOvdb7eC3B4E6ZbC94IiLiZXGD2jn3P8CHZhZafWQG8FYqi8j2giciIl6W6KiPK4EHekd8vANcksoisr3giYiIlyUU1M65TUB1OgvJ5oInIiJeppmJIiIep6AWEfE4BbWIiMcpqEVEPE5BLSIyWC11cMtEqC0NfmypS+nTaysuEZHBaKmDJxZAoHcm9a4Pg48BKual5FPojlpEZDCe++WekA4JdASPp4iCWkRkMHa1Jnd8HyioRUQGY0R5csf3gYJaRGQwZlwP/gELyPlLgsdTREEtIjIYFfNg9moYMRqw4MfZq1PWkQga9SEikpiWumAH4a7WYLPGjOv3hHHFvJQG80AKahGReDIwBC8WNX2IiMSTgSF4sSioRUTiycAQvFgU1CIi8WRgCF4sCmoRkXgyMAQvFgW1iEg8GRiCF0tCoz7M7D3gb0A30OWcS+u2XCIinpPmIXixJDM87x+cczvSVomIiESkpg8REY9LNKgd8B9m1mRm89NZkIiIhEu06WOac67NzA4FnjGzt51zG/pf0Bvg8wHGjBmT4jJFRApXQnfUzrm23o+fAI8BUyJcc4dzrto5Vz1q1KjUVikiUsDiBrWZDTOzA0J/B04B3kx3YSIiEpRI08dhwGNmFrr+Qefcv6e1KhGRfdTQ3Maq9dv4qL2DI0pLWDRrPDVVZdkua1DiBrVz7h1gcgZqEREZlIbmNhY98gaBHgdAW3sHix55AyCnw1rD80Qkb9Su29IX0iGBHkftui1Zqig1FNQikjfaOwJJHc8VCmoREY9T
UItI3jhoqD+p47lCQS0ieWPZ7An4fRZ2zO8zls2ekKWKUkN7JopI3giN7Ci44XkiIrmkpqos54N5IAW1iOSXJ6+BpnvBdYP54OiL4Yybs13VoCioRSR/PHkNNN6157Hr3vM4h8NanYkikj+a7k3ueI5QUItI/nDdyR3PEQpqEckf5kvueI5QUItI/jj64uSO5wh1JopIzom6lGmow1CjPkREsqehuY2l9ZvpCATbndvaO1havxlgT1jneDAPpKYPEckpq9ZvY2b3S2wsXsA7+/2QjcULmNn9EqvWb8t2aWmjoBaRnFL9+TOs9N9JedEOigzKi3aw0n8n1Z8/k+3S0kZBLSI5ZWnxIwy1zrBjQ62TpcWPZKmi9FNQi0hOOYwdSR3PBwkHtZn5zKzZzJ5MZ0EiIrHYiPKkjueDZO6orwK2pqsQEZGEzLge/CXhx/wlweN5KqGgNrNy4HTgzvSWIyISR8U8mL0aRowGLPhx9urg8TyV6DjqW4FfAAdEu8DM5gPzAcaMGTP4ykREoqmYl9fBPFDcO2ozOwP4xDnXFOs659wdzrlq51z1qFGjUlagiEihS6TpYyowx8zeAx4GppvZ/WmtSkRE+sQNaufcUudcuXNuLHAO8Lxz7vy0VyYiIoCXxlG31MEtE6G2NPixpS7bFYmIeEJSizI5514EXkx5FS118MQCCHQEH+/6MPgYCqrDQEQkEm/cUT/3yz0hHRLoCB4XESlw3gjqXa3JHRcRKSDeCOpoUz/zeEqoiEiivBHUBTglVEQkUd4I6gKcEioikijvbMVVYFNCRUQS5Y07ahERiUpBLSLicQpqEUk/zTweFO+0UYtIftLM40HTHbWIpJdmHg+aglpE0kszjwdNQS0i6aWZx4OmoBaR9NLM40FTZ2ICGprbWFrfQkegB4Aigx8eO4YbaiZluTKRHBDqMHzul8HmjhHlwZBWR2LCFNRxNDS3cc2aTfT0O9bj4P5XPwBQWIskQjOPB0VNH3GsWr8tLKT7e+i1DzNai4gUJgV1HB+1d0Q91+1cBisR8TBNaEmruE0fZrY/sAHYr/f6tc65ZekuzCuOKC2hLUpY+8wyXI2IByUwoaWhuY1V67fxUXsHR5SWsGjWeGqqyrJUcO5J5I76a2C6c24yUAmcambHpbcs71g0a3zUL9K5x47OaC0inhRnQkuwM34zbe0dOKCtvYOl9ZtpaG7LfK05Km5Qu6Aveh/6e/8UzHv+mqoybj67khL/ni9VkcH5x2nUhwgtdcE76Eh6J7SsWr+NjkB32KmOQDer1m9Ld3V5I6FRH2bmA5qA/wXc7px7LcI184H5AGPGjElljVlXU1Wmt2kiA4WaPKLpndASrZ8nVv+PhEuoM9E51+2cqwTKgSlmNjHCNXc456qdc9WjRo1KdZ3ZpY4Skb1FavII6Teh5YjSkoiXRDsue0tq1Idzrh14ATg1PeV4UOiuYdeHgAt+rP8R3Dcn25WJZFestTr6baW3aNZ4Svy+sNMlfh+LZo1PZ3V5JW5Qm9koMyvt/XsJMBN4O92FeUa0u4Z3X4Inr8l8PSJeEXUNj9Fhk1tqqspYMXcSZaUlGFBWWsKKuZPUnJiERNqovwHc19tOXQTUOeeeTG9ZHhLrrqHpXjjjZg09ksLRUrdnKnjJQVDkh57AnvNR1vBQP8/gxA1q51wLUJWBWrxpRHn0Xm3X3Tf0KNSrHRp6BOgbU/LLwPHSHZ+CrxhKDoaOz7SGRxpprY94ZlwfbJOOxHysWr+Nmd0v8YviOo6wHXzkRnJT1zxWrS9WUEt+idQM2N0JxcNg8bvZqalAaAp5PBXzYNxJkc8dfTHVnz/DSv+dlBftoMigvGgHK/13Uv35M5mtUyTdtAFA1iioE3HROqi+DKy359p8wcdn3MzS4kcYap1hlw+1Tm4t/q2G8kl+0QYAWaOmj0SdcXPwzwCHsSPi5Qaw60O6Hr8y+EVWu53kuhnXh7dRQ1/noTrU00t31INkce4m
hnR/Rc9jP9adteS+innB8dEjRgMW/Dh7NQ3dU7WWR5qZS8NSndXV1a6xsTHlz+tJA3vCo/GXhE0CEMkXU1c+H3GFybLSEl5eMj0LFeUmM2tyzlVHOqc76sHqd5cR83dev9XERPKJ1vJIPwV1KlTMg6vfZLl/IbtdcdTLnHrHJc80NLdRFGVddq3lkToK6hSqPH0+17v5dLnIX9a/MjLDFYmkT2iyV6SdjrSWR2opqFOopqqMaWf+hGsC/3evO+vdrpgVnT8IPtBqfJIHIq0zDcGdj7SWR2opqFOspqqMpgNnsiRwOa09I+lxRmvPSJYELqfxwJmRV+N7YoHCWnJOtDboHucU0immcdRpsGjWeJbWd7Kuc1rfsRK/jxWzxsNzEUaIhDoaNSJEcki0/UTVNp16uqNOg5jLOkadhvuhmkIkp2id6czRHXWaRF3WMdZqfH0bE8yHD16NOBNSJJsGzkD8P0eX8cLb2zUjMc0U1JkWaRruXhw03g1jjlNziHhGpCV9H21qU8dhBqjpI9MGTsONymmCjKTPPow80m7i2aOgzoKG7qlM/Xo14756gP8hxkbAmiAj6bCPI480AzF7FNQZFnr7GFrA5sbOH9ATbep5yUEaby2pF2kDgASWONBu4tmTyOa2o83sBTN7y8y2mNlVmSgsXw18+7iuZxr/1v1degZe6CuGr/+29+7nvx6nwJbB2ccNADTKI3sSuaPuAn7mnDsKOA64wsyOSm9Z+SvS28RlXZdydedPwpePLB4evmloSMenmiAjg7OPGwBoN/HsSWRz24+Bj3v//jcz2wqUAW+luba8FG2SQOOBM+HqFXsO1JZGfxJNkJFkDdw93Fcc3O8wJMru4QNpN/HsSKqN2szGEtyR/LUI5+abWaOZNW7fvj011eWhSG8f/UXG7s4uxi15iqkrnw8uuB5veyN1NEqiWuqg4Sd7mtE6PoXuruDu4f02ANAvfu9KeBy1mQ0HHgUWOuc+H3jeOXcHcAcENw5IWYV5JnQ3Epo0MKLEz5edXXy2O9jM0dbewdVrNtH5d5cwb/eq6OOttU+dJOrpxRGa0Xp7RWrbM16OJC+hO2oz8xMM6Qecc/XpLSn/1VSV8fKS6by78nSG7TeEQHf47zUHLP7zt3l90vLeu54BEnybKgIE76CTOS6ek8ioDwPuArY65zSnOcWijUF1wMK3joTF78Lc34V3NE7+YbC9UcP2RApCIk0fU4ELgM1mtqn32D865/6QvrIKR7TORegX4hXz9rQfDtyjMTRZIXSdyEAlB0e+e470bk08Ke4dtXNuo3POnHMVzrnK3j8K6RRZNGt81InkEScS7ONkBSlg3/t1cJRHf77i4HHJCZqZmGU1VWWcd9yYvcI66kSCfZysIAWsYh58//bw5rPv3653YDlEq+d5wA01k6j+5sFhy0eGQnrqyufDl5CMtkyqRoFILP2bzyTnmIuwMeVgVVdXu8bGxpQ/byEZuKQkBO+yf3/M+xyzeVl484e/RONgJSkD15XWOtLZZ2ZNzrnqSOfU9OFR0ZaUPOeV0cFhe/3fxiqkJQkDFwZra+9gaf3m4EQr8SQ1fXhUtGF73c5x4evfZMXc9ZHvgPpPFR5RHhxvrRCXfmKtK627am/SHbVHxVo6Mupi7S118PgV4SvuPX6FxllLGK0rnXsU1B4VaU2Q/iL+UD29OHyhHQg+fnJhiquTXKZ1pXOPmj48KvQW9Gd1b9AdocM34g9VtCnBnV8G76rVBFKw+ncejijx4/dZ2NIFWlfa2xTUHhYK64GjP/qvtpdwj339j7SzeT558hpouhdcN5gPjr446v9tQ3Mbix55g0DvVkLtHQGKgIOG+mnfHdCojxygoPa4RFbbW1q/OXhttKnCIY13wXsvw0/3WqVWcsmT1wT/L0Nc957HEcK6dt2WvpAO6QGcg3dXnp7GQiVV1EadA+KtttfXuZjIlOAdbwd/0CV3Nd2T1PH2jgg7BcU4Lt6joM4xMXvs
K+ZB9WUQdfWQXk33prwuySC31w6bsY9LzlNQ55i4PfZn3Axz74j9JK479nnxppa64LK2STpoqD+p4+I9Cuock9BO0BXzwD8s9hNpHevcEradVhRR/s+XzZ6A3xf+LsvvM5bNnpDKCiWNFNQ5Jt5O0A3NbUxd+TxXfXkRMd8Ih9axVljnhojbafVXBLNvjXimpqqMVWdNDvueWXXWZI3yyCFalCmPNDS3sWjtG32djXOKNnKj/y6G2dfRW61LDg7uIiPeVjsi+rkRo7VUQB7QokwFYvkTW8JGhKzrmcbEr+9hYs8aonYwdnyqu2qvizdK5+o3FdJ5TkGdR0Jjqwf6srOb3SWHR/+H2h3Gu1rqoPHu6Oe1nVZBSGRz27vN7BMzezMTBUl6LNl1JlFbubQ7jDfdNyc4o5QYzZPaTqsgJHJHfS9waprrkBQoLYk+3GpdzzQ+Y3jkk9odxnvumwPvvhTzkp09wxn74DDGLnmK8373SoYKk2xIZHPbDUCMecniFbVzYg+3qg1cyG43YJNTLDgCRMP1vCVOSPc4WN51Yd/jl//yqcI6j6WsjdrM5ptZo5k1bt++PVVPK0moqSrj/Agb5Yas65nGksDltPaMxBF6Q937tnrXh8G32b86VIGdbXE6D3sc/Fv3d1nXMy3s+Mt/0f1UvkpZUDvn7nDOVTvnqkeNGpWqp5Uk3VAziVvOrox6fl3PNKZ1rqatZ2TkQO/+Gh77scI6WwYuuDSAc7Aw8BOWdV2awaIk2zTqIw/VVJUx9VuxRwMcYTuin3Q9wQkWkllxQhrga4bsdSct+U9Bnace+NHxHHlo9GnkH7mRsZ8g1nKpknotdXFD2jn4RWB+1PPxfjlL7kpkeN5DwCvAeDNrNbPL0l+WpMIz15zMrWdXUhZhIaebuuZFH64nmdVSF5zOH4Nz8J89E6LeTU/91sE88KPj01GdeICmkBeIhuY2rl6zKWxE7u/9/8QJRVuwaL2PmpqcGbdMjLnYknPw++7vRmyXPmion+brT0lndZIhmkIu1FSVcd6AESEXBq7lP3smxJgIo4WbMiLGhKNYIQ1oBbwCoaAuIKERIf2bQi4MXMtVgZ8Eh+xFCuxAh6aYp1uUCUfxQvr848ZoBbwCoaAuMKFtvfqHdWjIXrQba6cJMWn1+reupGPARKTdrpirogzDM4IhfUPNpAxVKNmmoC5QYRsN9Io2EsQAdn1I1+NXKqwHK7RLS21p3y+/hW8dyeLeiUg9zmjtGcmSwOUROw6LgFvOrlRIFxh1Jhaw6xo2c/+rH/Q9nlO0kZX+OxlqnVH/zW72Z2jtXzNRXv4Jje4I9Nv30l/CVV9ewuMJjI0+cD8fLcu17E6+UmeiRHRDzSRuPbuSEn/w2yBsinmU398l7iteWX1x5orMJ08vDg9pgEAHS4sfiftPzz9ujEK6gA3JdgGSXTVVZWHbeC2t97GucxobixdQHmH2ohkcs/PxTJeZ+1rqok4iOowdlPh9dAQibzqs9mjRHbX06b8fY6wJMT56mLryecYteYqpK5+nobkts4XmohgjZzpKDu/7uvfnM1NIC6A2aomiobmN2Q0T8Nne3x/dDj52IznCdvCRG8lNXfN41nciN86t0HCxKFxtKRZhXI1zsNy/kNrrlmehKvEStVFL0mqqynhu6Ol73VU7B44iyot2UGRQXrSDW/2/ZbG7k0Vr39Ddda/X1/0rXyw7FLdsBG7ZCKK9PfmM4dz3xZQMVye5RkEtUZ2y+AFePeRMulwRzkGXK+JL9meI9YRdV2Rwoe9ZrrO7WLhmE99a+geua9icpaqz7/V1/0pV02KG29eY0fdnYFbvdsXUBi7kiAhrsYj0p85Eien4Bff2/X0IMLS2NOJ1ZnCB71maev6OdT3TuP/VD6hvai2Y5pCG5jZq122hvSPAxuKbGFK09x10KKwdxkfuEG7qmscTPdO4JcKYdpH+1EYtSdn9628ztOPjqOedg26KeKB7etisOp8Z5x47
Ou86xhqa21j+xBY+2x1gTtFGfjGkjjLbEXWhK+dg3NcPAsGJROeps1B6xWqjVlBLclrqcPU/irrdV4hz8CX78Y+By8Jm2BX7jJvOmpzTd9nXNWzmodc+pLvfz87yIXdzge9ZiuJ8YbpcEUd+fT9HlJawaNb4nP46SGopqCW1nrwG13hX3LCG8HbZgcFdloNhdd7vXgnbm3BO0UZq/b/nIL6IvlxsL+fg1UPODGtOEglRUEvqPXkNNN4NUZdyiize2srLZk/wXHA3NLexav022trDZxUmMuU+9OPVg/GnQ2oU0hKVglrSo6UuOJEjxqL3kfT/lvuM4dQGLgxrHinxF7Eii52QwRmaLXQEevY6N6doIzf672IYXwPEvYveXfINhi5+Ox1lSp5RUEt6JbApayz9vwXfdmV8r3PVXtcUGfzw2NR2vF3XsJkHX/uAngR/BOYUbeRm/78wJMIkoEgcYHN/px1yJCGDDmozOxW4DfABdzrnVsa6XkFdgJ68BpruCe5gPgj9vx0jjR7ZF0ZvaEYYy5yIREZzRPys1ZfCGTcn/wmlIA0qqM3MB/wZmAm0Aq8D5zrn3or2bxTUBWwfm0OiifTtmaoAT0Qi7dB7KTkYvvdr3UlLUgYb1McDtc65Wb2PlwI451ZE+zcKagH2ucMxEdG+bVMd4huLF1BetPcqghFpM2AZhMEG9VnAqc65y3sfXwAc65z76YDr5gPzAcaMGXP0+++/n4raJR+k+C47noHf0l+4/bi267KIO6ZEs3zI3Zznex4fPYk1d1RfpmYOGZSMBHV/uqOWqFrqggvoR1mbOV2CU7fpG/sdabRJyO/9/8QJRVsSC2grgqMvUUjLoMUK6kTW+mgDRvd7XN57TCR5FfP2NA201EH9fNLRNDKQGWETdA7mC27z/5bb+G3fsW4HPttzfVT+Epi9Wk0ckjGJrJ73OnCkmY0zs2LgHGBdesuSglAxD2rbYe7vgh1wGdZ/ZTszGFK05+9RjRitkJaMi3tH7ZzrMrOfAusJDs+72zm3Je2VSeHof5cN8M/Hwg4PThIxH1z9ZrarkAKU0DKnzrk/AH9Icy0iQT99Lfb5++bAuy9lppb+jr44859TBK1HLbnooigtb+kcXTLuJHUYSkmgCMMAAALUSURBVNYoqCV/DGxCgQHhHZqjmISiYqi5XW3SklUKaslvA8M7ZnAPeDzupOh37yIZpKCWwhLprlvE47S5rYiIxymoRUQ8TkEtIuJxCmoREY9TUIuIeFxatuIys+2A19c5HQkkuNBwXiik11tIrxUK6/Xm82v9pnNuVKQTaQnqXGBmjdGWFMxHhfR6C+m1QmG93kJ6rf2p6UNExOMU1CIiHlfIQX1HtgvIsEJ6vYX0WqGwXm8hvdY+BdtGLSKSKwr5jlpEJCcoqEVEPK6gg9rMVpnZ22bWYmaPmVlptmtKFzP7gZltMbMeM8vb4U1mdqqZbTOz/zazJdmuJ53M7G4z+8TM8n5/MDMbbWYvmNlbvd/HV2W7pkwq6KAGngEmOucqgD8DS7NcTzq9CcwFNmS7kHQxMx9wO/A94CjgXDM7KrtVpdW9wKnZLiJDuoCfOeeOAo4Drsjz/9swBR3Uzrn/cM519T58FSjPZj3p5Jzb6pzblu060mwK8N/OuXecc53Aw8D3s1xT2jjnNgCfZruOTHDOfeyc+6/ev/8N2AqUZbeqzCnooB7gUuDpbBchg1IG9N8wsZUC+mEuFGY2FqgC4uyCnD/yfocXM3sWODzCqWudc4/3XnMtwbdWD2SytlRL5LWK5DIzGw48Cix0zn2e7XoyJe+D2jn33Vjnzexi4AxghsvxQeXxXmsBaANG93tc3ntM8oCZ+QmG9APOufps15NJBd30YWanAr8A5jjndme7Hhm014EjzWycmRUD5wDanTYPmJkBdwFbnXM3Z7ueTCvooAb+GTgAeMbMNpnZ/8t2QeliZmeaWStwPPCUma3Pdk2p1tsx/FNgPcHOpjrn
3JbsVpU+ZvYQ8Aow3sxazeyybNeURlOBC4DpvT+rm8zstGwXlSmaQi4i4nGFfkctIuJ5CmoREY9TUIuIeJyCWkTE4xTUIiIep6AWEfE4BbWIiMf9fzGan+7zZOrTAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "xs = np.random.normal(size=(128, 1))\n",
    "ys = xs ** 2\n",
    "\n",
    "for _ in range(1000):\n",
    "  params = update(params, xs, ys)\n",
    "\n",
    "plt.scatter(xs, ys)\n",
    "plt.scatter(xs, forward(params, xs), label='Model prediction')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sBxOB21YNEDA"
   },
   "source": [
    "## Custom pytree nodes\n",
    "\n",
    "So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "CK8LN2PRFnQf"
   },
   "outputs": [],
   "source": [
    "class MyContainer:\n",
    "  \"\"\"A named container.\"\"\"\n",
    "\n",
    "  def __init__(self, name: str, a: int, b: int, c: int):\n",
    "    self.name = name\n",
    "    self.a = a\n",
    "    self.b = b\n",
    "    self.c = c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "OPGe2R7ZOXCT",
    "outputId": "40db1f41-9df8-4dea-972a-6a7bc44a49c6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<__main__.MyContainer at 0x7fdec166ce50>,\n",
       " <__main__.MyContainer at 0x7fded89ba490>]"
      ]
     },
     "execution_count": 9,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_leaves([\n",
    "    MyContainer('Alice', 1, 2, 3),\n",
    "    MyContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vk4vucGXPADj"
   },
   "source": [
    "Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "vIr9_JOIOku7",
    "outputId": "dadc9c15-4a10-4fac-e70d-f23e7085cf74"
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "ignored",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m jax.tree_map(lambda x: x + 1, [\n\u001b[1;32m      2\u001b[0m     \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m ])\n",
      "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/tree_util.py\u001b[0m in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf)\u001b[0m\n\u001b[1;32m    184\u001b[0m   \"\"\"\n\u001b[1;32m    185\u001b[0m   \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    188\u001b[0m def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,\n",
      "\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m jax.tree_map(lambda x: x + 1, [\n\u001b[0m\u001b[1;32m      2\u001b[0m     \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m ])\n",
      "\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for +: 'MyContainer' and 'int'"
     ]
    }
   ],
   "source": [
    "jax.tree_map(lambda x: x + 1, [\n",
    "    MyContainer('Alice', 1, 2, 3),\n",
    "    MyContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nAZ4FR2lPN51",
    "tags": [
     "raises-exception"
    ]
   },
   "source": [
    "To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "D_juQx-2OybX",
    "outputId": "ee2cf4ad-ec21-4636-c9c5-2c64b81429bb"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2, 3, 4, 5, 6]"
      ]
     },
     "execution_count": 11,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from typing import Tuple, Iterable\n",
    "\n",
    "def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:\n",
    "  \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
    "  flat_contents = [container.a, container.b, container.c]\n",
    "\n",
    "  # we don't want the name to appear as a child, so it is auxiliary data.\n",
    "  # auxiliary data is usually a description of the structure of a node,\n",
    "  # e.g., the keys of a dict -- anything that isn't a node's children.\n",
    "  aux_data = container.name\n",
    "  return flat_contents, aux_data\n",
    "\n",
    "def unflatten_MyContainer(\n",
    "    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:\n",
    "  \"\"\"Converts aux data and the flat contents into a MyContainer.\"\"\"\n",
    "  return MyContainer(aux_data, *flat_contents)\n",
    "\n",
    "jax.tree_util.register_pytree_node(\n",
    "    MyContainer, flatten_MyContainer, unflatten_MyContainer)\n",
    "\n",
    "jax.tree_leaves([\n",
    "    MyContainer('Alice', 1, 2, 3),\n",
    "    MyContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
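  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once a class is registered, a tree map can look inside it and rebuild it, with the auxiliary data passed through unchanged. The sketch below registers a separate, hypothetical `LabeledPair` class (so that it does not re-register `MyContainer`) and uses the `jax.tree_util.tree_map` spelling; all names in it are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "class LabeledPair:\n",
    "  \"\"\"Hypothetical container: a label plus two numeric children.\"\"\"\n",
    "\n",
    "  def __init__(self, label, left, right):\n",
    "    self.label = label\n",
    "    self.left = left\n",
    "    self.right = right\n",
    "\n",
    "def flatten_pair(pair):\n",
    "  # Children are the values to map over; the label is auxiliary data.\n",
    "  return (pair.left, pair.right), pair.label\n",
    "\n",
    "def unflatten_pair(label, children):\n",
    "  return LabeledPair(label, *children)\n",
    "\n",
    "jax.tree_util.register_pytree_node(LabeledPair, flatten_pair, unflatten_pair)\n",
    "\n",
    "# The map touches only the children and returns a new LabeledPair.\n",
    "incremented = jax.tree_util.tree_map(lambda x: x + 1, LabeledPair('demo', 1, 2))\n",
    "print(incremented.label, incremented.left, incremented.right)"
   ]
  },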
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JgnAp7fFShEB"
   },
   "source": [
    "Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "8DNoLABtO0fr",
    "outputId": "9a448508-43eb-4450-bfaf-eeeb59a9e349"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Alice', 1, 2, 3, 'Bob', 4, 5, 6]"
      ]
     },
     "execution_count": 12,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from typing import NamedTuple, Any\n",
    "\n",
    "class MyOtherContainer(NamedTuple):\n",
    "  name: str\n",
    "  a: Any\n",
    "  b: Any\n",
    "  c: Any\n",
    "\n",
    "# Since `tuple` is already registered with JAX, and NamedTuple is a subclass,\n",
    "# this will work out-of-the-box:\n",
    "jax.tree_leaves([\n",
    "    MyOtherContainer('Alice', 1, 2, 3),\n",
    "    MyOtherContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TVdtzJDVTZb6"
   },
   "source": [
    "Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kNsTszcEEHD0"
   },
   "source": [
    "## Common pytree gotchas and patterns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0ki-JDENzyL7"
   },
   "source": [
    "### Gotchas\n",
    "#### Mistaking nodes for leaves\n",
    "A common problem to look out for is accidentally introducing tree nodes instead of leaves:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "N-th4jOAGJlM",
    "outputId": "23eed14d-d383-4d88-d6f9-02bac06020df"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(DeviceArray([1., 1.], dtype=float32),\n",
       "  DeviceArray([1., 1., 1.], dtype=float32)),\n",
       " (DeviceArray([1., 1., 1.], dtype=float32),\n",
       "  DeviceArray([1., 1., 1., 1.], dtype=float32))]"
      ]
     },
     "execution_count": 13,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
    "\n",
    "# Try to make another tree with ones instead of zeros\n",
    "shapes = jax.tree_map(lambda x: x.shape, a_tree)\n",
    "jax.tree_map(jnp.ones, shapes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "q8d4y-hfHTWh"
   },
   "source": [
    "What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.\n",
    "\n",
    "The solution will depend on the specifics, but there are two broadly applicable options:\n",
    "* rewrite the code to avoid the intermediate `tree_map`.\n",
    "* convert the tuple into an `np.array` or `jnp.array`, which makes the entire\n",
    "sequence a leaf."
   ]
  },
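  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both options can be sketched as follows. The names `zeros_tree`, `ones_direct`, `shape_leaves` and `ones_from_shapes` are made up for this illustration, and the `jax.tree_util.tree_map` spelling is used throughout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "zeros_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
    "\n",
    "# Option 1: skip the intermediate tree of shapes altogether.\n",
    "ones_direct = jax.tree_util.tree_map(jnp.ones_like, zeros_tree)\n",
    "\n",
    "# Option 2: store each shape as an np.array, so the whole shape is one leaf.\n",
    "shape_leaves = jax.tree_util.tree_map(lambda x: np.array(x.shape), zeros_tree)\n",
    "ones_from_shapes = jax.tree_util.tree_map(\n",
    "    lambda s: jnp.ones(tuple(int(d) for d in s)), shape_leaves)\n",
    "\n",
    "print([x.shape for x in ones_direct])\n",
    "print([x.shape for x in ones_from_shapes])"
   ]
  },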
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4OKlbFlEIda-"
   },
   "source": [
    "#### Handling of None\n",
    "`jax.tree_util` treats `None` as a node without children, not as a leaf:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "gIwlwo2MJcEC",
    "outputId": "1e59f323-a7b7-42be-8603-afa4693c00cc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 14,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_leaves([None, None, None])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pwNz-rp1JvW4"
   },
   "source": [
    "### Patterns\n",
    "#### Transposing trees\n",
    "\n",
    "If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_multimap`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "UExN7-G7qU-F",
    "outputId": "fd049086-ef37-44db-8e2c-9f1bd9fad950"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'obs': [3, 4], 't': [1, 2]}"
      ]
     },
     "execution_count": 15,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def tree_transpose(list_of_trees):\n",
    "  \"\"\"Convert a list of trees of identical structure into a single tree of lists.\"\"\"\n",
    "  return jax.tree_multimap(lambda *xs: list(xs), *list_of_trees)\n",
    "\n",
    "\n",
    "# Convert a dataset from row-major to column-major:\n",
    "episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]\n",
    "tree_transpose(episode_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ao6R2ffm2CF4"
   },
   "source": [
    "For more complicated transposes, JAX provides `jax.tree_transpose`, which is more verbose but lets you specify the structure of the inner and outer pytrees for more flexibility:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "bZvVwxshz1D3",
    "outputId": "a0314dc8-4267-41e6-a763-931d40433c26"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'obs': [3, 4], 't': [1, 2]}"
      ]
     },
     "execution_count": 16,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_transpose(\n",
    "  outer_treedef = jax.tree_structure([0 for e in episode_steps]),\n",
    "  inner_treedef = jax.tree_structure(episode_steps[0]),\n",
    "  pytree_to_transpose = episode_steps\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KlYA2R6N2h_8"
   },
   "source": [
    "## More Information\n",
    "\n",
    "For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "jax101-pytrees",
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
